import torch

# Registry of re-ID feature-extractor configurations, keyed by model name.
#
# Each entry is a dict with:
#   'params'            -- keyword arguments for the model; presumably passed to
#                          the model constructor by the loader (not shown here).
#   'weights'           -- path to the checkpoint file to load.
#   'load_weight_method'-- how to load the checkpoint: 'name' selects a loader
#                          strategy (e.g. 'base_load', 'extract_load',
#                          'remove_prefix_load', 'remove_prefix_extract_load');
#                          the remaining keys ('dict_name', 'prefix', 'strict')
#                          are options for that strategy.
#   'train_config'      -- (optional) path to the training YAML config.
#
# NOTE(review): some 'weights' entries are absolute, machine-specific Windows
# paths (D:\..., F:\...); they will not resolve on other machines. Consider
# moving them under the repo-relative 'pth/' directory like the other entries.
MODEL_CONFIGS = {
    'osnet_transformer': {
        'params': {
            'num_classes': 2494,
            'feature_dim': 512,
            'd_model': 512,
            'num_heads': 4,
            'num_layers': 3,
        },
        'weights': r'D:\work\DD_SORT-main\MOT\feature_extract\pth\osnet_transformer_checkpoint_epoch_3_10-17.pth',
        'load_weight_method': {
            'name': 'adapted_extract_load',
            'dict_name': 'model_state_dict',
        },
    },
    'osnet_x1_0': {
        'params': {
            'num_classes': 2510,
            'pretrained': False,
        },
        'weights': 'pth/osnet_ms_d_c.pth.tar',
        'load_weight_method': {
            'name': 'remove_prefix_extract_load',
            'prefix': 'module.',
        },
    },
    'osnet_ain_x1_0': {
        'train_config': 'model/torch_reid/config/im_osnet_ain_x1_0_softmax_256x128_amsgrad_cosine.yaml',
        'params': {
            'loss': 'softmax',
            'num_classes': 2494,
            'pretrained': False,
        },
        'weights': r'F:\temp2\TEMP\temp_tra\weights\osnet_ain_ms_d_m.pth.tar',
        'load_weight_method': {
            'name': 'remove_prefix_extract_load',
            'prefix': 'module.',
            # Non-strict load: missing/unexpected checkpoint keys are tolerated.
            'strict': False,
        },
    },
    'BaseReIDModel': {
        'params': {
            'pretrained': False,
            'class_num': 6768,
        },
        'weights': 'pth/multiple_datasets_trained_model.pth',
        'load_weight_method': {
            'name': 'remove_prefix_load',
            'prefix': 'module.',
        },
    },
    'resnet50': {
        'params': {
            'num_classes': 4101,
            'pretrained': False,
        },
        'weights': 'pth/resnet50_msmt17_combineall_256x128_amsgrad_ep150_stp60_lr0.0015_b64_fb10_softmax_labelsmooth_flip_jitter.pth',
        'load_weight_method': {
            'name': 'base_load',
        },
    },
    'fast_reid_resnet': {
        'params': {
            'with_nl': True,
            'with_ibn': True,
            'pretrained': False,
            'depth': '50x',
            'pool_type': 'GeneralizedMeanPoolingP',
            'feat_dim': 2048,
            'embedding_dim': 0,
            'num_classes': 0,
            'neck_feat': "after",
            'cls_type': "CircleSoftmax",
            'scale': 64,
            'margin': 0.35,
            'with_bnneck': True,
            'norm_type': "BN",
            'pixel_std': None,
            'loss_kwargs': None,
            'pixel_mean': None,
        },
        'weights': 'pth/fast_reid_msmt_sbs_R50-ibn.pth',
        'load_weight_method': {
            'name': 'extract_load',
            'dict_name': 'model',
            'strict': False,
        },
    },
    'SkeletonReIDModel': {
        'params': {
            'class_num': 751,
            # Resolved once at import time. Falls back to CPU so this entry does
            # not hard-code a CUDA device on hosts without a usable GPU.
            'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            'pretrained': False,
            'weight_global': 16.0,
            'test_mode': 'sum',
            'classifier': {
                'name': 'linear',
            },
            'pose_model_path': 'pth/pose_hrnet_w48_256x192.pth',
        },
        'weights': 'pth/occluded_reid_resnet50ibna.pth',
        'load_weight_method': {
            'name': 'remove_prefix_load',
            'prefix': 'module.',
        },
    },
}

# Disabled alternative for the 'osnet_ain_x1_0' entry, using the repo-relative
# weights path and strict loading. Swap it in for the entry above if needed:
#
#     'osnet_ain_x1_0': {
#         'train_config': 'model/torch_reid/config/im_osnet_ain_x1_0_softmax_256x128_amsgrad_cosine.yaml',
#         'params': {
#             'loss': 'softmax',
#             'num_classes': 2494,
#             'pretrained': False,
#         },
#         'weights': 'pth/osnet_ain_ms_d_m.pth.tar',
#         'load_weight_method': {
#             'name': 'remove_prefix_extract_load',
#             'prefix': 'module.',
#             'strict': True
#         }
#     },
